{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import gluon, nd, autograd\n",
    "from mxnet.gluon import nn\n",
    "import numpy as np\n",
    "import pickle as p\n",
    "import matplotlib.pyplot as plt\n",
    "from time import time\n",
    "%matplotlib inline\n",
    "ctx = mx.gpu()\n",
    "data_dir = '/home/sinyer/python/data/cifar10'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_cifar(route = data_dir+'/cifar-10-batches-py'):\n",
    "    def load_batch(filename):\n",
    "        with open(filename, 'rb')as f:\n",
    "            data_dict = p.load(f, encoding='latin1')\n",
    "            X = data_dict['data']\n",
    "            Y = data_dict['labels']\n",
    "            X = X.reshape(10000, 3, 32,32).astype(\"float\")\n",
    "            Y = np.array(Y)\n",
    "            return X, Y\n",
    "    def load_labels(filename):\n",
    "        with open(filename, 'rb') as f:\n",
    "            label_names = p.load(f, encoding='latin1')\n",
    "            names = label_names['label_names']\n",
    "            return names\n",
    "    label_names = load_labels(route + \"/batches.meta\")\n",
    "    x1, y1 = load_batch(route + \"/data_batch_1\")\n",
    "    x2, y2 = load_batch(route + \"/data_batch_2\")\n",
    "    x3, y3 = load_batch(route + \"/data_batch_3\")\n",
    "    x4, y4 = load_batch(route + \"/data_batch_4\")\n",
    "    x5, y5 = load_batch(route + \"/data_batch_5\")\n",
    "    test_pic, test_label = load_batch(route + \"/test_batch\")\n",
    "    train_pic = np.concatenate((x1, x2, x3, x4, x5))\n",
    "    train_label = np.concatenate((y1, y2, y3, y4, y5))\n",
    "    return train_pic, train_label, test_pic, test_label\n",
    "\n",
    "def accuracy(output, label):\n",
    "    return nd.mean(output.argmax(axis=1)==label).asscalar()\n",
    "\n",
    "def evaluate_accuracy(test_data, net, ctx):\n",
    "    acc = 0.\n",
    "    for data, label in test_data:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        output = net(data)\n",
    "        acc = acc + accuracy(output, label)\n",
    "    return acc / len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_pic, train_label, test_pic, test_label = load_cifar()\n",
    "\n",
    "batch_size = 128\n",
    "train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    train_pic.astype('float32')/255, train_label.astype('float32')), batch_size, shuffle=True)\n",
    "test_data = gluon.data.DataLoader(gluon.data.ArrayDataset(\n",
    "    test_pic.astype('float32')/255, test_label.astype('float32')), batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = gluon.nn.Sequential()\n",
    "with net.name_scope():\n",
    "    net.add(\n",
    "        nn.Conv2D(channels=20, kernel_size=5, activation='relu'),\n",
    "        nn.MaxPool2D(pool_size=2, strides=2),\n",
    "        nn.Conv2D(channels=50, kernel_size=3, activation='relu'),\n",
    "        nn.MaxPool2D(pool_size=2, strides=2),\n",
    "        nn.Flatten(),\n",
    "        nn.Dense(128, activation=\"relu\"),\n",
    "        nn.Dense(10)\n",
    "    )\n",
    "net.initialize(ctx=ctx)\n",
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.02, 'momentum': 0.9, 'wd': 5e-4})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 loss:1.8597 tracc:0.3219 teacc:0.4565 time:18.192\n",
      "10 loss:0.5632 tracc:0.8010 teacc:0.6735 time:1.861\n",
      "20 loss:0.2029 tracc:0.9293 teacc:0.6563 time:1.865\n",
      "30 loss:0.1151 tracc:0.9622 teacc:0.6613 time:1.868\n",
      "40 loss:0.0228 tracc:0.9964 teacc:0.6872 time:1.865\n",
      "50 loss:0.0079 tracc:1.0000 teacc:0.6937 time:1.870\n",
      "60 loss:0.0095 tracc:1.0000 teacc:0.6928 time:1.869\n",
      "70 loss:0.0099 tracc:1.0000 teacc:0.6932 time:1.874\n",
      "tracc:1.000000 teacc:0.693532\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzt3Xl8XXWd//HXJzdLm6VpmqT7XrpA\nC10IpQtLWZSKI+hjXFoXEBhxVEYZHBF0hhl1fjoOOIw6jIogIiqLBaV0CgVZRIVC05bu+54uSZq2\nWZv1fn9/fG/aS5o0tyXJuffm/Xw8ziP33Pu9537ukvc593vO+V5zziEiIsklJegCRESk6yncRUSS\nkMJdRCQJKdxFRJKQwl1EJAkp3EVEklCn4W5mvzCzMjNb38HtZmY/MrPtZrbWzGZ0fZkiInImYtly\n/yUw/zS3fwAYH5luBX7y3ssSEZH3otNwd869Dhw5TZPrgV85bznQ38yGdFWBIiJy5lK7YBnDgH1R\n8yWR6w62bWhmt+K37snKyrpw0qRJXfDwIl2voTlMWXU9xxtbaGgOB12OJJlh/fsyICv9rO67cuXK\nw865ws7adUW4WzvXtTumgXPuQeBBgKKiIldcXNwFDy/StZ59Zz93P7OOgSnGRaMHMHloPyYP7cfY\nwmxSU4xQipFihhm0jt4Rdg7DXwdgBmYn/zWccyfmY/mHaa+9tXdHTtYQvay2Tdve17mTj9neY3V0\nW3vL6agu6Vi/PmlkZZxd/JrZnljadUW4lwAjouaHAwe6YLkiPaqhuYV/X7KJx5bvoWhUHv/zyRkM\nzu0TdFkiZ6Urwn0xcJuZPQFcDFQ6507pkhGJJ845dlfUsftwLSXHjrP/6HFe31rOxoNVfO7SMdw5\nfxJpIR0pLImr03A3s8eBeUCBmZUA/wqkATjnfgosBa4FtgN1wE3dVazI6RysPM6afcdYv7+K9Qcq\n2XCgCgMuGJ7LlGG5TBmaS1l1A2/urGD5zgrKqxtO3DctZIwYkMlPPz2D+VN0PIAkvk7D3Tm3sJPb\nHfClLqtIJAbNLWE2H6pm9b5jFO8+QvHuo+w/dhyAUIoxfmA2l40vxDnH2v2VvLy57ETf9MCcDOaM\ny2fW2HwmDMpmWP9MBuZkkJKizmNJHl3RLSPSKeccZdUNbCutoW96iGkj+hM6gzB1zvHOvmO8sP4Q\nq/YeZd3+Suqb/FEsA3MyuGj0AP7u0jHMGJnHxME59EkLvev+tQ3NbDpYRV5WOmMLst61s1MkGSnc\npcuEw44tpdWs319JeU0D5dUNHK5ppORoHdvLaqiubz7RdkBWOldOGsjV5w5i8tB+9M9MIzsj9UTo\nOudoaA5TWlXPkrUHeXpVCTvLa0kPpXDe0H4snDmSaSP6M31EHiMG9O00rLMyUikaPaBbn79IPFG4\ny1k7WtvI5kM+zN/adYQVu49QebzpxO3ZGakUZKczOLcPH542jPGDsjmnMJuK2kb+uKmUFzccYtHK\nkhPtU1OM3L5pNLWEqW1soSV88hi/mWMG8PnLxnLt+UPI6ZPWo89TJBEp3OW0nHO8saOC7WU1lFXX\nU1bVwKGqeraV1nCoqv5Eu1H5mVwzeRAXj8ln+sj+DMntS9/0UIfL/dDUoTS1hFm15yh7jtRRWdfE\n0bpGjh1vIj2UQlZGiKyMVHL6pHH5+EJG5mf2xNMVSRoKd2lXOOx4cWMpP3p5GxsPVgF+R2VBdjoD\nc/owZ1w+k4bkMGlwPyYNyWFgzpkfD54WSuHisflcPDa/q8sX6fUU7vIuFTUNvLalnJ//eSebD1Uz\npiCL+z42lXkTCxmQma4jSkQShMK9F3POUV7TwKaD1by9q4LXtx5m/YFKnIOxBVnc/4mpfOiCoaTq\nZB6RhKNw72Wq6ptYVFzCy5tL2XywmoraRsB3ucwY2Z87rp7AZRMKOX9YrrbSRRKYwr2X2FZazaNv\n7uaZVfupa2xh0uAcrj530Il+88nD+tFPR6GIJA2FexKraWhm6bqDLCou4e3dR0gPpfChqUP57JzR\nnD88N+jyRKQbKdyT0Payan76p50sXXeQusYWxhZkcef8iXyiaAT52RlBlyciPUDhnkT2Hanjhy9v\n45lVJfRNC3H9tKF89MLhzBiZp9PtRXoZhXsSKKuu539f3cFv3tqDmXHz3DF88YpzzvqXXkQk8Snc\nE9ixukZ++qedPPrGbhpbwny8aDhfvmo8Q3L7Bl2aiARM4Z6Aahuaefgvu/j56zupaWzm+qlDuf3q\nCYwuyAq6NBGJEwr3BNLUEuaJt/fyw5e3cbimkWsmD+KO901k4uCcoEsTkTijcE8AzjmeX3+I/3xh\nM7sr6pg5ZgAP3jCJGSPzgi5NROKUwj3Ord9fybef28jbu48wYVA2D99YxJWTBuroFxE5LYV7nCqv\nbuDeZZv53coS8jLT+X8fmcKCi0ae0a8XiUjvpXCPQ1sOVfPZR97mcE0Df3fJGG67cjy5fTU0gIjE\nTuEeZ5bvrOBzvyqmb1qI339xLlOGaZgAETlzCvc4smTtAe54cg0jBvTl0ZtnMjxPvz4kImdH4R4H\nnHP87PWdfP+FzVw4Mo+Hbiyif6bOLhWRs6dwD1hdYzN3LlrLkrUH+eD5Q/jBx6fSJ63j3x4VEYmF\nwj1A+47U8blfFbOltJo750/kC5eP0yGOItIlFO4BeXvXEW59rJhw2PHIZy9i3sSBQZckIklE4R6A\nP24s5Uu/XcWwvL784saLNCaMiHQ5hXsPe3plCXc+vZYpQ/vxyE0zNSyviHQLhXsPevgvu/jOko3M\nPSefn32miOwMvfwi0j2ULj3kyRV7+c6SjXxgymD+e8E0MlJ1RIyIdB+Few9Yvfco//KHDVw6voAf\nL5xOaigl6JJEJMkpZbpZeXUDX/j1Kgb2y+BHCxTsItIztOXejZpawnzpt6s4WtfI01+YQ552nopI\nD1G4d6PvLt3E27uOcP8npmoAMBHpUQr3btDUEuZbz23g18v38tk5o/nI9OFBlyQivUxMHcBmNt/M\ntpjZdjO7q53bR5rZq2a22szWmtm1XV9qYjhW18iNv3ibXy/fy+cvH8u//M15QZckIr1Qp1vuZhYC\nHgDeB5QAK8xssXNuY1Szfwaecs79xMzOA5YCo7uh3ri2vayGv3t0BQeO1XPfx6by0Qu1xS4iwYil\nW2YmsN05txPAzJ4Argeiw90B/SKXc4EDXVlkIthyqJpPPPgmqSnG47dezIWjBgRdkoj0YrGE+zBg\nX9R8CXBxmzb/BrxoZv8AZAFXt7cgM7sVuBVg5MiRZ1pr3NpTUctnHn6LjNQUnvr8bEbla6wYEQlW\nLH3u7Y1B69rMLwR+6ZwbDlwLPGZmpyzbOfegc67IOVdUWFh45tXGoUOV9Xz64bdoagnz61suVrCL\nSFyIJdxLgBFR88M5tdvlFuApAOfcm0AfoKArCoxnR2ob+czDb3G0tolHb57J+EE5QZckIgLEFu4r\ngPFmNsbM0oEFwOI2bfYCVwGY2bn4cC/vykLjzfHGFm765Qr2HKnj5zcUccHw/kGXJCJyQqfh7pxr\nBm4DlgGb8EfFbDCzb5vZdZFmXwU+Z2ZrgMeBzzrn2nbdJI1w2HHHU++wtuQYP144ndnj8oMuSUTk\nXWI6ick5txR/eGP0dfdEXd4IzO3a0uLXvS9u4fn1h/jnD57LNZMHB12OiMgpNIrVGXpqxT5+8toO\nPnnxSG65ZEzQ5YiItEvhfgbe3FHBN36/jkvHF/Ct6ybrx6xFJG4p3GN0pLaRLz+xmlH5mfzPJ2eQ\npqF7RSSOaeCwGDjnuOvptVTWNfGrm2eS2zct6JJERE5Lm58xeKp4Hy9uLOXO+RM5d0i/zu8gIhIw\nhXsndh2u5VvP+R+1vnmudqCKSGJQuJ9GU0uY2598h7RQCvd9bCopKdqBKiKJQX3up/H95zezZt8x\nHvjkDIbk9g26HBGRmGnLvQOPvbmbh/6yixtnj+KDFwwJuhwRkTOicG/HK5tL+dfFG7j63IHc86HJ\nQZcjInLGFO5trN9fyW2/Xc3kobn8aOF0QupnF5EEpHCPUlpVz82/XEFeZjoP31hEZrp2SYhIYlJ6\nRfmvF7dyrK6J5/7hEgb26xN0OSIiZ01b7hG7DteyaFUJn5o1komD9aMbIpLYFO4RP/zjVtJDKXxh\n3rigSxERec8U7sDW0mqeXXOAG+eMZmCOumMkiTkHjbVBVyE9QH3uwH//cStZ6al8/rKxQZciiaKh\nGprqoW9/CLUzkJxzUHcEqg9A1QGoPgRNddDc4KeWRjADDCzl5JSSAhbyy2iuh6bjvn1TLdRX+cdt\nqIaMbBgwDvLH+b8pKSdvb6wBF/bLw/xjVWyDss1QvgUaq6FvHuSN9lO/YZCe5ae0TAg3Q00Z1Jb7\nqbkBUkKRGkPgWvwyW5oh3ASpGZCe7e+b1nqyn/OvAQahVAil+8kMwmG/DBeGcIu/HG6db44su8lP\nKSF/v9Q+fjnNDX7l1FjrX5tolgLpmb6O9GzfvqV1eY2RNuafQ0rI13fisd3JxwqlQUqqv771vi7s\nrw9l+OebmhGpK8NfZxZ5XxugufHkcwg3RT2fqOc1+4sw6YPd8ck8odeH+/r9lSxdd4ivXDWevKz0\noMuReNFYCxufheqDUFvhQ66m1M9XHfQB2So9G/r09wHbdDwy1flA6IhFvjSfrg1EAqSvD80+/SAj\nx0/Hj8LaJ6GhKrbnk1kAA8+FqQug3xCo3A9Hd8PBtbB1ma83WkoqZBVCVoEP1uggTkmFlLSTodvc\nAHX7/Aqo6Th+hRVZceHeHdjhlsiKIhRZkaWcDFsLnVwRpKT5y+FwJDAb/P3T+kBaViTE+558HcHX\nVl/lV6SNNT7YQ2knA9tS3r0yiQ761tuiQzklLXL/NN+upfFkLa0r6OYGv7zW9zSUAanpJ1dmobST\nr9WJWtIjr0336vXhfv9LW8ntm8Ytl2pQMImo2AFPfhrKNvr59GwfclmFUDgJxl0JOUP8FmL9MTh+\nzP91YR84aZk+9LIKoN9QyBkKOYP91nbrll9K6OTjOefvGx2g4JcR3a4t56D2MBzZ6YOqNfjTsyOh\nF1muhfyK4XTCYR/wTXU+vFtXVtK51i3/UHzFaXxV08NW7jnKy5vL+No1E+nXR2O0C7DtJXj6Fh+O\nn/wdjLk0qquhm7RuQRJqv4vndPfLLvTTe5WS4lc+GdnvfVm9zelWwAHqtatm5xzfW7qJgTkZ3DR3\ndFcsEHb/1fdpdqSlKdIPKXHHOXj9PvjNxyB3JNz6Gkx4f/cHu0g36bVb7i9tLKV4z1G++5Hz3/uZ\nqAdWw7Jvwp6/+vmxV8CsL8A57/Nf1zcvgXWLYPeffX9bZoH/yt43L9J/2drfmOa/Umdk+51b2YOg\nYDwUTPBf7bvya3I4HNvyju6GvcthWBEUnNN1j3+2Gusifa0x9lm2NPv3Z8cr/v0ZNRcu+6d3b221\nNMOSr8DqX8OUj8J1P/Z9uiIJrFeGe3NLmO+/sJlxhVl8vGj4qQ1qD8OuP8HO1/yOtSu+6Y9KaKty\nP7zyHVjzuA/sa++D+kpY8RD89uP+KISaMr9zZsBYmP2lyPIjO+jqj/mdTa1HCrQ0+sdrqPE77KJ3\ntqVlwrnXwQf+w68U2mqo9jvZTuzMczBk6qlfGSt2wPNfh+0vAeZXLqE038eaP87XOWAsVJb4QDyy\nw98vJRVmfh4uv9MfIdKRo7th8//51/D4UT+1NJ7sD87I8a/LkKkwaErnIeoclG6ATc/5qWwDZORC\n3ih/pMeAMX7lVzDBrwgxH+b7V/ppz1/9e4L55/Xad2Hvm/C3D0NWvl9ZLLoJtr4Al38d5t0d+4pD\nJI6ZC6iboKioyBUXFwfy2L99ay/f+P06fvaZC7lm8uCTN+z+K7zwdTi0zs9n5ALOd6dcdQ9c/Hkf\nljXl8Jf7fYjjYNYX4dKvntxp1dLkj7RY+5QPnCl/C0Onn1loOOdXAIe3weGtcHANrH7Mb81/5Kcw\n5jLf7sgu+PMP/Aom3PzuZeSOgAtvhOk3QJ9cX/Nf7vffHi680W8Bh5t9vXUVPviP7PCX07Jg9CV+\n5+Hwi2DVo7DqV5A5AK74hn9O0SuZqoPw+r2+TbjJfxPpm+fbhzL8yqqh2h/NEG7y97EUH8pDp8Ow\nC/00aApUlcCeN2HvG7Drz3BsD2AwcrZ/3seP+JVI69R6mFtb+efAiFlwzpUwZp4P85WPwtKvQfZA\nuO5H8Or3oGQFfPAHcNEtsb8/IgExs5XOuaJO2/W2cK9rbObye19j5IBMFv39bKw1cNf+Dp79IuQO\nh2mf8l0rQ6b6gF1yu9+yG3GxD5i3fw7Nx2HqQr+1lzeqZ4rfvwqe+ZwP4Vlf8EG55nG/VT3jBhg6\n7eTRGg3Vvpth15/87ZkFUHPIdzu8/9/94XAdqa/0h9+ltjk09OAaeOHuk91PA8bC0Bl+a7x15TLj\nBrjkH/2Kpb2VmXP+uO+DayLTO34Lu7bc324pJ7+x9M3zr/eEa2DitT6Q22pp9uHfuhJsaYRhM/wK\no71vOK2v41M3QOU+v6L724fgvOtP/9qLxAmFewd+/PI2fvDSVhb9/WyKRg/wYfPnH/julVGXwIJf\nnxoKzvljip+/0wff5I/AvG9A4YQer5/GWnjxX6D4YX+o3IU3wdyvdBzWh7fDykd818Ylt8PYee/t\n8Z3z4b7vLR+SB1b7Y78vWADzvu67Ss5mmZUlsL/YH3edOwxGzvGHHXbX4Xi1FfCn78PkD8OoOd3z\nGCLdQOHejoqaBi6/9zVmj8vn5zcU+TMMn/+a70o4/+Nw/f/4Y5A7UnfEbxH31Jb66ZRt9iuhnEFB\nV+K7dc7kED4ROWuxhnuv2qH641e2U9fYzNfnT4Rdr8Nzt/s+5kv/Ca785877xDMH+CkeDJwUdAUn\nKdhF4k6vCfc9FbX85q093DS9H+e88XV45zeQNwY+8wcYd0XQ5YmIdKleE+73vbiVvJRavrHnH+H4\nYbjkDn9Yn05SEZEk1CvCfW3JMZ5bc4DHx71BaP8huHkZjJwVdFkiIt0m6YcfcM7xH89vZnTfemaV\nPwnnfVjBLiJJL+nD/fVth3ljRwU/GvlnrLEO5t0VdEkiIt0u6cP9By9u4fy8Bs4/8CSc/1E/prWI\nSJKLKdzNbL6ZbTGz7WbW7qavmX3czDaa2QYz+23Xlnl2tpfVsLakku8NfBVrrvdnk4qI9AKd7lA1\nsxDwAPA+oARYYWaLnXMbo9qMB+4G5jrnjppZO+eJ97zn1x2kkKNMLnkKLvhEZGApEZHkF8uW+0xg\nu3Nup3OuEXgCaDsQx+eAB5xzRwGcc2VdW+bZ+b91B/m3vBexcJM/7FFEpJeIJdyHAfui5ksi10Wb\nAEwws7+a2XIzm9/egszsVjMrNrPi8vLys6s4RjvLa2gp3cT8+v+D6Z/yg1yJiPQSsYR7e+fktx2Q\nJhUYD8wDFgIPmdkpg3475x50zhU554oKC7vgp8FO4/l1B/he2kNYRg5c9a/d+lgiIvEmlnAvAUZE\nzQ8HDrTT5lnnXJNzbhewBR/2gXHFj1CUspWU+d/1v3okItKLxBLuK4DxZjbGzNKBBcDiNm3+AFwB\nYGYF+G6anV1Z6JnYt3sHN9Q+wv68mX7MdRGRXqbTcHfONQO3AcuATcBTzrkNZvZtM7su0mwZUGFm\nG4FXga855yq6q+jONC75J9JpJnT9D/WTaSLSK8U0toxzbimwtM1190RddsAdkSlYm5Yw7vAr/Cr7\nRm4YfV7Q1YiIBCK5Bg4Lh2ladg87w8NpuOiLQVcjIhKY5Bp+YPsfSTu2gweaP8z8C0YGXY2ISGCS\nK9yX/y8VKfnsGXw1IwZkBl2NiEhgkifcyzbBzlf5RePVXDKxgx+LFhHpJZIn3Jf/hJZQBr9pvpK5\n43Rcu4j0bskR7rUVsPZJ3smbz/HUXGaMygu6IhGRQCVHuK98BJrr+VnD+ykanUeftFDQFYmIBCrx\nw725EVY8ROPoebxYnsccdcmIiCRBuG98FqoPsnqIH2Zgzrj8gAsSEQle4of7hmcgdyTP1kwiJyOV\n84flBl2RiEjgEjvcw2HYuxzGXMZfdhzl4rH5pIYS+ymJiHSFxE7Cw1vh+BGO5M9g75E65p6jLhkR\nEUj0cN/7JgDLWyYCMPcc7UwVEYFkCPesQpYdzKQgO4PxA7ODrkhEJC4kfLi7kbN5Y+cR5ozLxzR2\nu4gIkMjhXrkfju2lPG8G5dUN6m8XEYmSuOF+or99AoBOXhIRiZLY4Z6ezeuVgxjcr4+G+BURiZLA\n4b4chl/EgepmhvbvE3Q1IiJxJTHD/fgxKN0AI2dTWlXPoH4KdxGRaIkZ7vveBhyMmk1ZVYPCXUSk\njcQM971vQEoqtYXTqG5oVriLiLSRoOG+HIZMo/S4L39Qv4yACxIRiS+JF+5N9bB/JYycRWlVA4C2\n3EVE2ki8cD+wGloaYdQcyqrrAYW7iEhbiRfue9/wf0fM4lBla7irW0ZEJFpq0AWcsakLoWACZOVT\nWlVKZnqI7IzEexoiIt0p8VKx31A/AaXV9Qzu10cDhomItJF43TJRSivrGaguGRGRUyR2uFfr7FQR\nkfYkbLg75yitamCwwl1E5BQJG+6Vx5tobA4zUOEuInKKhA33Q1U6DFJEpCMJG+46O1VEpGMJHO5+\ny1197iIip4op3M1svpltMbPtZnbXadp91MycmRV1XYntK42cnVqYo24ZEZG2Og13MwsBDwAfAM4D\nFprZee20ywG+DLzV1UW2p7S6nv6ZafRJC/XEw4mIJJRYttxnAtudczudc43AE8D17bT7DvCfQH0X\n1tchHQYpItKxWMJ9GLAvar4kct0JZjYdGOGcW3K6BZnZrWZWbGbF5eXlZ1xstLKqeh0GKSLSgVjC\nvb2BW9yJG81SgPuBr3a2IOfcg865IudcUWFhYexVtuNQVT2D1N8uItKuWMK9BBgRNT8cOBA1nwNM\nAV4zs93ALGBxd+5UbQk7yqsbGJyrLXcRkfbEEu4rgPFmNsbM0oEFwOLWG51zlc65AufcaOfcaGA5\ncJ1zrrhbKgYqahoIO9QtIyLSgU7D3TnXDNwGLAM2AU855zaY2bfN7LruLrA9J85OVbeMiEi7YhrP\n3Tm3FFja5rp7Omg7772XdXo6O1VE5PQS8gzVE2enqs9dRKRdCRnuZVX1pBjkZ6UHXYqISFxKyHA/\nVFVPQXYGqaGELF9EpNslZDqWVukwSBGR00nQcK9nYI7CXUSkIwkb7vqRDhGRjiVcuDc0t3C0rkmD\nhomInEbChXuZjnEXEelUwoV76zHuA9UtIyLSoQQMd225i4h0JgHDXb+dKiLSmYQL90lDcvjsnNH0\nz0wLuhQRkbgV08Bh8WTOuALmjCsIugwRkbiWcFvuIiLSOYW7iEgSUriLiCQhhbuISBJSuIuIJCGF\nu4hIElK4i4gkIYW7iEgSUriLiCQhhbuISBJSuIuIJCGFu4hIElK4i4gkIYW7iEgSUriLiCQhhbuI\nSBJSuIuIJCGFu4hIElK4i4gkIYW7iEgSUriLiCShmMLdzOab2RYz225md7Vz+x1mttHM1prZy2Y2\nqutLFRGRWHUa7mYWAh4APgCcByw0s/PaNFsNFDnnLgAWAf/Z1YWKiEjsYtlynwlsd87tdM41Ak8A\n10c3cM696pyri8wuB4Z3bZkiInImYgn3YcC+qPmSyHUduQV4vr0bzOxWMys2s+Ly8vLYqxQRkTMS\nS7hbO9e5dhuafRooAu5t73bn3IPOuSLnXFFhYWHsVYqIyBlJjaFNCTAian44cKBtIzO7GvgmcLlz\nrqFryhMRkbMRy5b7CmC8mY0xs3RgAbA4uoGZTQd+BlznnCvr+jJFRORMdBruzrlm4DZgGbAJeMo5\nt8HMvm1m10Wa3QtkA78zs3fMbHEHixMRkR4QS7cMzrmlwNI2190TdfnqLq5LRETeA52hKiKShBTu\nIiJJSOEuIpKEFO4iIklI4S4ikoQU7iIiSUjhLiKShBTuIiJJSOEuIpKEFO4iIklI4S4ikoQU7iIi\nSUjhLiKShBTuIiJJSOEuIpKEFO4iIklI4S4ikoQU7iIiSUjhLiKShBTuIiJJSOEuIpKEFO4iIklI\n4S4ikoQU7iIiSUjhLiKShBTuIiJJSOEuIpKEFO4iIklI4S4ikoQU7iIiSUjhLiKShBTuIiJJSOEu\nIpKEFO4iIklI4S4ikoRiCnczm29mW8xsu5nd1c7tGWb2ZOT2t8xsdFcXKiIises03M0sBDwAfAA4\nD1hoZue1aXYLcNQ5dw5wP/D9ri5URERiF8uW+0xgu3Nup3OuEXgCuL5Nm+uBRyOXFwFXmZl1XZki\nInImUmNoMwzYFzVfAlzcURvnXLOZVQL5wOHoRmZ2K3BrZLbGzLacTdFAQdtlx5F4rS1e64L4rS1e\n64L4rS1e64LkqW1ULI1iCff2tsDdWbTBOfcg8GAMj3n6gsyKnXNF73U53SFea4vXuiB+a4vXuiB+\na4vXuqD31RZLt0wJMCJqfjhwoKM2ZpYK5AJHuqJAERE5c7GE+wpgvJmNMbN0YAGwuE2bxcCNkcsf\nBV5xzp2y5S4iIj2j026ZSB/6bcAyIAT8wjm3wcy+DRQ75xYDDwOPmdl2/Bb7gu4smi7o2ulG8Vpb\nvNYF8VtbvNYF8VtbvNYFvaw20wa2iEjy0RmqIiJJSOEuIpKEEi7cOxsKoYdr+YWZlZnZ+qjrBpjZ\nS2a2LfI3L4C6RpjZq2a2ycw2mNlX4qE2M+tjZm+b2ZpIXd+KXD8mMmzFtsgwFuk9WVebGkNmttrM\nlsRLbWa228zWmdk7ZlYcuS7wz1mkjv5mtsjMNkc+b7ODrs3MJkZeq9apysxuD7quqPr+MfL5X29m\nj0f+L7r8c5ZQ4R7jUAg96ZfA/DbX3QW87JwbD7wcme9pzcBXnXPnArOAL0Vep6BrawCudM5NBaYB\n881sFn64ivsjdR3FD2cRlK8Am6Lm46W2K5xz06KOhQ76vWz1Q+AF59wkYCr+tQu0NufclshrNQ24\nEKgDfh90XQBmNgz4MlDknJuCP0hlAd3xOXPOJcwEzAaWRc3fDdwdcE2jgfVR81uAIZHLQ4AtcfC6\nPQu8L55qAzKBVfiznQ8Dqe3P8xtXAAACw0lEQVS9xz1c03D8P/2VwBL8yXmB1wbsBgraXBf4ewn0\nA3YROTAjnmqLquX9wF/jpS5Ons0/AH+04hLgmu74nCXUljvtD4UwLKBaOjLIOXcQIPJ3YJDFREbo\nnA68RRzUFun2eAcoA14CdgDHnHPNkSZBvqf/DdwJhCPz+cRHbQ540cxWRobwgDh4L4GxQDnwSKQr\n6yEzy4qT2lotAB6PXA68LufcfuA+YC9wEKgEVtINn7NEC/eYhjkQz8yygaeB251zVUHXA+Cca3H+\n6/Jw/KB057bXrGerAjP7G6DMObcy+up2mgbxeZvrnJuB7478kpldFkAN7UkFZgA/cc5NB2oJrnvo\nFJF+6+uA3wVdS6tIP//1wBhgKJCFf1/bes+fs0QL91iGQghaqZkNAYj8LQuiCDNLwwf7b5xzz8RT\nbQDOuWPAa/h9Av0jw1ZAcO/pXOA6M9uNH/n0SvyWfOC1OecORP6W4fuOZxIf72UJUOKceysyvwgf\n9vFQG/jQXOWcK43Mx0NdVwO7nHPlzrkm4BlgDt3wOUu0cI9lKISgRQ/FcCO+v7tHmZnhzxre5Jz7\nr3ipzcwKzax/5HJf/Ad9E/AqftiKQOoCcM7d7Zwb7pwbjf9cveKc+1TQtZlZlpnltF7G9yGvJw4+\nZ865Q8A+M5sYueoqYGM81BaxkJNdMhAfde0FZplZZuT/tPU16/rPWVA7Ot7DDolrga34vtpvBlzL\n4/h+syb8Vswt+H7al4Ftkb8DAqjrEvzXurXAO5Hp2qBrAy4AVkfqWg/cE7l+LPA2sB3/FToj4Pd1\nHrAkHmqLPP6ayLSh9TMf9HsZVd80oDjynv4ByIuH2vA77CuA3KjrAq8rUse3gM2R/4HHgIzu+Jxp\n+AERkSSUaN0yIiISA4W7iEgSUriLiCQhhbuISBJSuIuIJCGFu4hIElK4i4gkof8Pp1Ko92a3+YMA\nAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fc9285f9908>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "epochs = 80\n",
    "\n",
    "a, b = [], []\n",
    "for epoch in range(epochs):\n",
    "    if epoch  == 40:\n",
    "        trainer.set_learning_rate(0.005)\n",
    "    if epoch  == 60:\n",
    "        trainer.set_learning_rate(0.001)\n",
    "    train_loss = 0.\n",
    "    train_acc = 0.\n",
    "    start = time()\n",
    "    for data, label in train_data:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            l = loss(output, label)\n",
    "        l.backward()\n",
    "        trainer.step(batch_size)\n",
    "        train_loss = train_loss + nd.mean(l).asscalar()\n",
    "        train_acc = train_acc + accuracy(output, label)\n",
    "    test_acc = evaluate_accuracy(test_data, net, ctx)\n",
    "    \n",
    "    if epoch%10 == 0:\n",
    "        print(epoch, 'loss:%.4f tracc:%.4f teacc:%.4f time:%.3f'%(\n",
    "            train_loss/len(train_data), train_acc/len(train_data), test_acc, time()-start)) \n",
    "    a.append(train_acc/len(train_data))\n",
    "    b.append(test_acc)\n",
    "\n",
    "print('tracc:%f teacc:%f'%(train_acc/len(train_data), test_acc))\n",
    "plt.plot(np.arange(epochs), a, np.arange(epochs), b)\n",
    "plt.ylim(0,1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
